from functools import lru_cache

import numpy as np

from ..base import Property
from ..updater import Updater
from ..types.prediction import MeasurementPrediction
from ..types.update import Update


class AlphaBetaUpdater(Updater):
    r"""Conceptually, the :math:`\alpha-\beta` filter is similar to its Kalman cousins in that it
    operates recursively over predict and update steps. It assumes that a state vector is
    decomposable into quantities and the rates of change of those quantities. We refer to these as
    position :math:`p` and velocity :math:`v` respectively, though they aren't confined to
    locations in space. The interval from :math:`t_{k-1} \rightarrow t_k` is :math:`\Delta T`
    and, at :math:`k`, a (noisy) measurement of the position, :math:`p^z_k`, is available.

    The recursion proceeds as:

    * Predict

    .. math::

        p_{k|k-1} &= p_{k-1} + \Delta T v_{k-1}

        v_{k|k-1} &= v_{k-1}

    * Update

    .. math::

        s_k &= p^z_k - p_{k|k-1} \: (\mathrm{innovation})

        p_k &= p_{k|k-1} + \alpha s_k

        v_k &= v_{k|k-1} + \frac{\beta}{\Delta T} s_k

    The :math:`\alpha` and :math:`\beta` parameters which give the filter its name are small,
    :math:`0 < \alpha < 1` and :math:`0 < \beta \leq 2`. Colloquially, the larger the values of the
    parameters, the more influence the measurements have over the transition model; :math:`\beta`
    is usually much smaller than :math:`\alpha`.

    As the prediction is just the application of a constant velocity model, there is no
    :math:`\alpha-\beta` predictor provided in Stone Soup. It is assumed that the predictions
    passed to the hypothesis have been generated by a constant velocity model. Any application of a
    control model is also assumed to have taken place during the prediction stage.

    This class assumes the velocity is in units of the length per second. If different units are
    required, scale the prior appropriately.

    The measurement model used should be linear. Its :attr:`mapping` tuple must identify the
    elements of the state vector that hold :math:`p`, and its (binary) measurement matrix must
    return :math:`p` when applied to the state vector. This isn't checked.

    """

    alpha: float = Property(doc="The alpha parameter. Controls the weight given to the "
                                "measurements over the transition model.")
    beta: float = Property(doc="The beta parameter. Controls the amount of variation allowed in "
                               "the velocity component.")

    # NOTE(review): `update` uses this value directly as a NumPy index into the state vector, so
    # an integer index array or a boolean mask selects the velocity elements. A 0/1 *integer*
    # array would be read as indices 0 and 1 rather than as a mask - confirm the intended format.
    vmap: np.ndarray = Property(default=None, doc="Binary map of the velocity elements in the "
                                                  "state vector. If left default, the class will "
                                                  "assume that the velocity elements interleave "
                                                  "the position elements in the state vector.")

    # lru_cache memoises on every argument (including ``self`` and the keyword arguments), so all
    # of them must be hashable.
    @lru_cache()
    def predict_measurement(self, prediction, measurement_model=None, **kwargs):
        """Return the predicted measurement

        The measurement matrix of the measurement model is applied to the predicted state vector.

        Parameters
        ----------
        prediction : :class:`~.StatePrediction`
            The state prediction
        measurement_model : :class:`~.MeasurementModel`, optional
            The measurement model to use. It is passed through
            :meth:`_check_measurement_model` before use, so may be left as `None`.
        **kwargs : various
            Passed to the :meth:`matrix` method of the measurement model.

        Returns
        -------
         : :class:`~.MeasurementPrediction`
            The predicted measurement, built from `prediction` with the predicted measurement
            vector as its state vector
        """
        # This is necessary if predict_measurement is called on its own
        measurement_model = self._check_measurement_model(measurement_model)
        pred_meas = measurement_model.matrix(**kwargs) @ prediction.state_vector
        return MeasurementPrediction.from_state(prediction, pred_meas)

    def update(self, hypothesis, time_interval, **kwargs):
        """Calculate the inferred state following update

        Parameters
        ----------
        hypothesis : :class:`~.Hypothesis`
            A hypothesis associates a measurement with a prediction
        time_interval : :class:`~.timedelta`
            The time interval over which the prediction has been made. Must be non-zero.
        **kwargs : various
            Passed to :meth:`predict_measurement` when the hypothesis does not already carry a
            measurement prediction.

        Returns
        -------
         : :class:`~.Update`
            The updated state

        Raises
        ------
        ZeroDivisionError
            If `time_interval` is zero, as the velocity gain :math:`\\beta / \\Delta T` is then
            undefined.
        """
        # Validate the interval before doing any work. A plain float beta would raise on the
        # division below anyway, but a NumPy-typed beta would silently yield inf/nan velocities.
        delta_t = time_interval.total_seconds()
        if delta_t == 0:
            raise ZeroDivisionError(
                "time_interval must be non-zero: the velocity gain beta / delta_t is undefined")

        prediction = hypothesis.prediction
        measurement = hypothesis.measurement

        # Check for the measurement_model in the measurement, if not present use the one in this
        # updater
        measurement_model = self._check_measurement_model(measurement.measurement_model)

        # Use the measurement prediction in the hypothesis if there is one, otherwise make it
        pred_meas = hypothesis.measurement_prediction
        if pred_meas is None:
            pred_meas = self.predict_measurement(
                prediction, measurement_model=measurement_model, **kwargs)

        # Position elements come from the measurement model; by default each velocity element is
        # taken to sit immediately after its position element.
        pmap = np.array(measurement_model.mapping)
        vmap = pmap + 1 if self.vmap is None else self.vmap

        innovation = measurement.state_vector - pred_meas.state_vector

        # Work on a copy so the prediction's state vector is left untouched; elements outside
        # pmap and vmap carry over unchanged.
        out_statevector = prediction.state_vector.copy()
        out_statevector[pmap] = prediction.state_vector[pmap] + self.alpha * innovation
        out_statevector[vmap] = prediction.state_vector[vmap] + (self.beta / delta_t) * innovation

        return Update.from_state(prediction, out_statevector,
                                 timestamp=measurement.timestamp, hypothesis=hypothesis)
